[TPU] add kv cache update kernel #19928
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
Summary of Changes
Hello @yaochengji, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request introduces a specialized kernel for updating the KV cache on TPUs, addressing the inefficiency of scatter-update operations on this hardware. By leveraging a Pallas-based kernel, it enables more performant and consecutive writes to the KV cache, which is crucial for optimizing large language model inference on TPU platforms. The changes involve implementing the new kernel, integrating it into the existing attention backend, and updating the data preparation logic to support the kernel's specific input requirements.
Highlights
- TPU KV Cache Optimization: Implemented a new Pallas-based kernel (`_kv_cache_update_kernel`) for efficient KV cache updates on TPUs. This kernel is specifically designed to handle consecutive memory writes, which are significantly more performant than scatter-update operations on the TPU architecture, addressing a key performance bottleneck.
- Integration with vLLM Attention Backend: Integrated the new KV cache update kernel into the vLLM Pallas attention backend. The previous `kv_cache.index_copy_` method has been replaced with a custom XLA operation (`torch.ops.xla.kv_cache_update_op`) that invokes the new Pallas kernel.
- Dynamic Slot Mapping for Kernel: Introduced a new `_get_slot_mapping_metadata` function to dynamically compute detailed slice mappings (source and destination indices, and lengths) for KV cache updates. This metadata is crucial for the new kernel to efficiently identify and transfer data segments.
- CI and Testing: Added a dedicated kernel test (`test_kv_cache_update_kernel.py`) to the `.buildkite` CI pipeline. This ensures the correctness and accuracy of the new KV cache update operation on TPUs by comparing its output against a CPU-based reference implementation.
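For orientation, here is a minimal CPU-side sketch of what the slice-based update computes, mirroring the CPU-reference approach used by the kernel test. The tensor shapes follow the comments in the kernel; the function name and the exact ordering of the slice triple are illustrative assumptions, not the PR's code.

```python
import torch

def kv_cache_update_reference(new_kv: torch.Tensor,
                              slot_mapping: torch.Tensor,
                              kv_cache: torch.Tensor) -> torch.Tensor:
    """CPU reference: copy each contiguous slice of new_kv into kv_cache.

    new_kv:       [num_tokens, num_combined_kv_heads, head_dim]
    slot_mapping: [num_slices, 3], assumed rows of
                  (kv_cache_start, new_kv_start, slice_len)
    kv_cache:     [total_num_pages * page_size, num_combined_kv_heads, head_dim]
    """
    for cache_start, new_start, length in slot_mapping.tolist():
        if length == 0:  # padded slices carry zero length
            continue
        kv_cache[cache_start:cache_start + length] = \
            new_kv[new_start:new_start + length]
    return kv_cache
```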
The kernel implementation is based on discussions with @bythew3i and @vanbasten23. We can observe some performance improvement.
Code Review
The pull request introduces a new kernel for updating the KV cache on TPUs, which aims to improve performance by reducing scatter updates. The changes include adding a test for the new kernel, incorporating the kernel into the vLLM workflow, and defining necessary functions and constants. The code appears well-structured and includes necessary documentation. However, some comments could be added to improve readability.
slot_mapping_xla = slot_mapping_cpu.to(torch_xla.device())
torch_xla.sync()

torch.ops.xla.dynamo_set_buffer_donor_(kv_cache_xla, True)
Why do we want to do this `torch.ops.xla.dynamo_set_buffer_donor_`?
Because it should be an in-place update.
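To make the in-place intent concrete, here is a minimal sketch of the buffer-donation pattern from the snippet above (it requires torch_xla; the tensor shape and names are illustrative). Marking the cache tensor as a donor lets XLA alias its buffer with the op's output, so the cache is updated in place rather than reallocated on every step.

```python
import torch
import torch_xla

device = torch_xla.device()
# Illustrative KV cache: [total_num_pages * page_size, num_combined_kv_heads, head_dim]
kv_cache_xla = torch.zeros(1000 * 32, 16, 128, device=device)

# Donate the cache buffer: XLA may reuse it for the output of the update op,
# avoiding a full copy of the (potentially large) KV cache.
torch.ops.xla.dynamo_set_buffer_donor_(kv_cache_xla, True)
```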
        scratch_shapes=scratch_shapes,
    ),
    out_shape=out_shape,
    input_output_aliases={len(scalar_prefetches) + 1: 0},
I guess this maps kv_cache_hbm_ref to the output so that you don't need to specify the output in "_kv_cache_update_kernel"?
Yeah, they're just aliases.
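For readers unfamiliar with the mechanism, here is a minimal, self-contained Pallas sketch, unrelated to the PR's kernel (the kernel body and shapes are made up), showing how `input_output_aliases` ties an input buffer to an output so that kernel writes land in the same buffer:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_one_kernel(x_ref, o_ref):
    # Input 0 is aliased to output 0, so writing o_ref updates the buffer
    # that x_ref reads from; no separate output needs to be materialized.
    o_ref[...] = x_ref[...] + 1.0

x = jnp.zeros((8, 128), jnp.float32)
y = pl.pallas_call(
    add_one_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    input_output_aliases={0: 0},  # maps input index -> output index
)(x)
```

In the PR the aliased input index is written as `len(scalar_prefetches) + 1`, presumably because the scalar-prefetch arguments are counted before the tensor inputs in the kernel's argument order.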
                         kv_cache: torch.Tensor, page_size: int,
                         block_size: int):
    from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update
    new_kv_cache = xb.call_jax(kv_cache_update, (kv, slot_mapping, kv_cache), {
Very nice!
slices_end = self.input_batch.num_computed_tokens_cpu[:num_reqs] + \
    num_scheduled_tokens_per_req
local_block_start_idx = slices_start // self.block_size
local_block_end_idx = (slices_end - 1) // self.block_size
why is there a "-1" here?
oh nvm, I figured it out.
Because `slices_end` is not included in the range.
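A quick worked example of why the exclusive end needs the `-1` (the numbers are illustrative):

```python
block_size = 16
slices_start, slices_end = 0, 16       # the slice covers tokens [0, 16)
last_block_wrong = slices_end // block_size        # 1 -- off by one
last_block_right = (slices_end - 1) // block_size  # 0 -- the block actually touched
```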
global_block_start_idx = np.repeat(global_block_start_idx, block_lens)
slice_arange = np.concatenate([self.arange_np[:n] for n in block_lens])
global_block_indices = global_block_start_idx + slice_arange
block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor()
do you know why we only get the first element of input_batch.block_table?
I remember it's a list of lists for more complex uses, but I don't know the details too well. Here the logic basically reuses the old one.
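As a side note, here is a small NumPy sketch of the repeat-and-offset pattern used in the snippet above (the array values are made up; in the runner the start indices come from the block-table layout):

```python
import numpy as np

block_lens = np.array([2, 1, 3])               # blocks touched per request
global_block_start_idx = np.array([0, 5, 9])   # first touched block per request

starts = np.repeat(global_block_start_idx, block_lens)        # [0 0 5 9 9 9]
offsets = np.concatenate([np.arange(n) for n in block_lens])  # [0 1 0 0 1 2]
global_block_indices = starts + offsets                       # [0 1 5 9 10 11]
```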
LGTM if the new jax imports don't disturb the other backends in CI. Just a few comments.
page_num = 1000
page_size = 32
combined_kv_head_num = 16
head_dim = 128
kernel_block_size = 16
padded_num_tokens = 128
Are there reasonable edge cases we could test here? Like block size != 16, an odd number of tokens, small kv head num, etc
Thanks, Michael, for the great suggestion. I added different head_dim, page_size, and kv_head_num values. The decoding, prefill-within-a-page, and prefill-across-pages cases are already covered in the test.
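For reference, parameter sweeps like this are typically expressed with pytest parametrization; the sketch below only mirrors the parameter names from the snippet above, and the test body is a placeholder rather than the PR's actual test:

```python
import pytest

@pytest.mark.parametrize("page_size", [16, 32])
@pytest.mark.parametrize("combined_kv_head_num", [2, 16])
@pytest.mark.parametrize("head_dim", [64, 128])
@pytest.mark.parametrize("num_tokens", [1, 33, 128])  # decode, odd count, full batch
def test_kv_cache_update_shapes(page_size, combined_kv_head_num, head_dim,
                                num_tokens):
    # Placeholder: build new_kv / slot_mapping / kv_cache for these shapes,
    # run the TPU kernel, and compare against a CPU reference update.
    ...
```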
vllm/v1/worker/tpu_model_runner.py (outdated)
# Block size used for kv cache updating kernel
KV_CACHE_UPDATE_BLOCK_SIZE = 8
It would be good to specify what dim this is blocking. For instance, why isn't it the same as the kv cache block size? What is the reason for not tuning this higher?
It's the kernel block size, which is different from the block size in vLLM (a.k.a. page size). I set it to 8 because min_num_seqs is 8, and I'd like to avoid trivial DMAs where possible.
or would it be better to put "kernel" in the name, such as KV_CACHE_KERNEL_BLOCK_SIZE to distinguish it from the block size in vLLM (page size)?
I changed it to NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK already.
vllm/v1/worker/tpu_model_runner.py (outdated)
pagged_num_slices = min(padded_num_slices, num_tokens)
pagged_num_slices = (
    pagged_num_slices + KV_CACHE_UPDATE_BLOCK_SIZE -
    1) // KV_CACHE_UPDATE_BLOCK_SIZE * KV_CACHE_UPDATE_BLOCK_SIZE
return pagged_num_slices
Suggested change:
pagged_num_slices = min(padded_num_slices, num_tokens)
pagged_num_slices = (
    pagged_num_slices + KV_CACHE_UPDATE_BLOCK_SIZE -
    1) // KV_CACHE_UPDATE_BLOCK_SIZE * KV_CACHE_UPDATE_BLOCK_SIZE
return pagged_num_slices
becomes:
padded_num_slices = min(padded_num_slices, num_tokens)
padded_num_slices = (
    padded_num_slices + KV_CACHE_UPDATE_BLOCK_SIZE -
    1) // KV_CACHE_UPDATE_BLOCK_SIZE * KV_CACHE_UPDATE_BLOCK_SIZE
return padded_num_slices
Done. Thanks for catching the typo.
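As a side note on the round-up idiom in the suggestion above: `(n + B - 1) // B * B` pads `n` up to the next multiple of the kernel block size `B`, so the kernel grid always operates on whole blocks. A tiny illustration (values are arbitrary):

```python
KV_CACHE_UPDATE_BLOCK_SIZE = 8

def round_up(n: int, block: int = KV_CACHE_UPDATE_BLOCK_SIZE) -> int:
    # e.g. 13 -> 16, 16 -> 16, 17 -> 24
    return (n + block - 1) // block * block
```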
""" A fixed shape of slot_mapping_metadata tensor is required to avoid | ||
recompilation. | ||
""" | ||
padded_num_slices = 2 * max_num_reqs + num_tokens // page_size |
could you explain these 2 lines
padded_num_slices = 2 * max_num_reqs + num_tokens // page_size
pagged_num_slices = min(padded_num_slices, num_tokens)
Why is pagged_num_slices determined this way?
Because the max possible number of slices is limited by num_reqs and num_tokens.
Not sure if I follow. I understand the number of slices can be bounded by num_tokens. Could we just do pagged_num_slices = num_tokens? I'm not sure where 2 * max_num_reqs + num_tokens // page_size comes from.
Let's say there are max_num_reqs requests, and request i has R_i tokens, so each request touches at most 2 + R_i // page_size pages. The total number of tokens across all requests is num_tokens. So the number of slices is sum_i (2 + R_i // page_size) <= 2 * max_num_reqs + (sum_i R_i) // page_size = 2 * max_num_reqs + num_tokens // page_size.
Regarding why each request has at most 2 + R_i // page_size pages: we chatted offline. Here we are estimating an upper bound on num_pages for a request. If there are 2 tokens and those 2 tokens straddle a page boundary, the request needs 2 pages.
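A quick numeric check of the bound (the numbers are made up):

```python
page_size = 32
max_num_reqs = 4
R = [1, 31, 33, 64]          # scheduled tokens per request
num_tokens = sum(R)          # 129

# Each request touches at most 2 partial pages plus R_i // page_size full pages.
slices_needed = sum(2 + r // page_size for r in R)              # 2 + 2 + 3 + 4 = 11
padded_num_slices = 2 * max_num_reqs + num_tokens // page_size  # 8 + 4 = 12
assert slices_needed <= padded_num_slices
```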
page_size = 32
combined_kv_head_num = 16
head_dim = 128
kernel_block_size = 16
Could you also test when the kernel needs to iterate through multiple blocks?
Thanks for the great suggestion! It's added.
        cache for the corresponding slice.
    - slice_len (int): The length of the slice.
    """
    slices_start = self.input_batch.num_computed_tokens_cpu[:num_reqs]
I wonder which test you ran to verify this change.
In the e2e accuracy test.
    page_size: int = 32,
    block_size: int = 8,
):
    assert slices.shape[0] % block_size == 0
Do you need to check other preconditions, such as new_kv.shape[1:] == kv_cache.shape[1:]? Also, is there any constraint on slice_len?
Yeah, I added the other necessary checks.
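A hedged sketch of the kind of host-side preconditions being discussed (the exact set of checks in the PR may differ, and the helper name is invented):

```python
def check_kv_cache_update_args(new_kv, slices, kv_cache,
                               page_size: int, num_slices_per_block: int) -> None:
    # Trailing dims (num_combined_kv_heads, head_dim) must match between
    # the new KV entries and the cache.
    assert new_kv.shape[1:] == kv_cache.shape[1:]
    # The slice table must be padded to whole kernel blocks.
    assert slices.shape[0] % num_slices_per_block == 0
    # The flattened cache must hold a whole number of pages.
    assert kv_cache.shape[0] % page_size == 0
```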
def _kv_cache_update_kernel(
    # Prefetch
    slices_ref,  # [num_slices, 3]
Consider making it [3, num_slices] or [3 * num_slices] to save SMEM.
Done. So there's some kind of padding for SMEM?
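To illustrate the layout change being discussed (a standalone NumPy sketch; only the indexing matters, and the values are zeros), the slice table is transposed so the long `num_slices` axis becomes the minor dimension, which is presumably what avoids the SMEM padding hinted at above:

```python
import numpy as np

num_slices = 4

# Row-per-slice layout: [num_slices, 3]; the third field is slice_len,
# matching the kernel comments.
slices_a = np.zeros((num_slices, 3), dtype=np.int32)
length_a = slices_a[1, 2]

# Field-per-row layout: [3, num_slices], as suggested to save SMEM.
slices_b = np.zeros((3, num_slices), dtype=np.int32)
length_b = slices_b[2, 1]  # same value, transposed indexing
```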
    # Output
    _,  # [total_num_pages * page_size, num_combined_kv_heads, head_dim]
    # Scratch
    scratch,  # [block_size, page_size, num_combined_kv_heads, head_dim]
Please rename block_size to something more meaningful, like "kv_pages_per_blk"? block_size is overused; e.g., it means page_size in vLLM, which will be very confusing for readers.
By reading the code, I think it might be better to rename block_size to num_slices_per_blk.
Done. Thanks for the awesome suggestion!
    # Prefetch
    slices_ref,  # [num_slices, 3]
    # Input
    new_kv_hbm_ref,  # [tokens, num_combined_kv_heads, head_dim]
Nit: rename tokens to num_tokens?
Done.
    slices_ref,  # [num_slices, 3]
    # Input
    new_kv_hbm_ref,  # [tokens, num_combined_kv_heads, head_dim]
    kv_cache_hbm_ref,
Nit: Add kv_cache shape in comment
Done.
    length = slices_ref[offset_i, 2]
    async_copy = pltpu.make_async_copy(
        new_kv_hbm_ref.at[pl.ds(new_kv_start, length), ...],
        scratch.at[i, pl.ds(0, length), ...],
Where do we guarantee that the length is not larger than the page_size?
Users of the kernel should guarantee the input is valid.
I added a TODO for the dynamic input check.
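A hedged sketch of the caller-side validation implied here (the helper is hypothetical, not the PR's code): every slice must fit within a single page so that the per-slice scratch copy of shape [page_size, ...] can hold it.

```python
import numpy as np

def validate_slice_lengths(slices: np.ndarray, page_size: int) -> None:
    # slices[:, 2] holds slice_len, per the kernel comments; each slice is
    # expected to stay within one page.
    assert (slices[:, 2] <= page_size).all(), \
        "a KV cache update slice must not span more than one page"
```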
This pull request has merge conflicts that must be resolved before it can be merged.
Purpose
TPUs are not good at scatter updates. With the new KV cache update kernel, consecutive new KV states are written to the cache together instead of being scattered token by token.
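For intuition, here is a minimal host-side sketch (a hypothetical helper, much simplified relative to `_get_slot_mapping_metadata`) of how a request's consecutive new tokens collapse into a handful of contiguous slices instead of per-token scatter indices:

```python
import numpy as np

def slices_for_request(block_table_row, num_computed_tokens, num_new_tokens,
                       page_size, new_kv_offset):
    """Group one request's new tokens into per-page contiguous slices.

    block_table_row:     physical page ids assigned to this request
    num_computed_tokens: tokens already cached (position of the first new token)
    new_kv_offset:       where this request's new tokens start in the new_kv tensor
    Returns rows of (kv_cache_start, new_kv_start, slice_len); a request yields
    at most 2 + num_new_tokens // page_size slices.
    """
    slices = []
    pos, written = num_computed_tokens, 0
    while written < num_new_tokens:
        page_idx, in_page = divmod(pos, page_size)
        length = min(page_size - in_page, num_new_tokens - written)
        cache_start = block_table_row[page_idx] * page_size + in_page
        slices.append((cache_start, new_kv_offset + written, length))
        pos += length
        written += length
    return np.array(slices, dtype=np.int32)
```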
Test Plan
Kernel test: pytest -s -v tests/v1/tpu/test_kv_cache_update_kernel.py
Accuracy test: pytest -s -v tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine
Test Result
Passed.